'''
Date: 2022-06-27 14:57:55
LastEditTime: 2022-06-27 16:17:22
'''
from PIL import Image
from pylab import *
import numpy as np
import cv2
from scipy.spatial import Delaunay
import matplotlib.pyplot as plt

def morphed_im(im1, im2, im1_pts, im2_pts, warp_func, dissolve_func):
  """生成 morphing 图像

  Args:
      im1 (ndarray): 原图像1的像素数组
      im2 (ndarray): 原图像2的像素数组
      im1_pts (ndarray): 原图像1上标注的点
      im2_pts (ndarray): 原图像2上标注的点
      warp_func (float): 向图像2变形(warping)的权重
      dissolve_func (float): 向图像2取色的权重
  
  Return:
      morphed_im(ndarray): 生成的图像
  """
  
  if (np.size(im1, 2) != 3 or np.size(im2, 2) != 3 or np.size(im1, 1) != np.size(im2, 1) or np.size(im1, 2) != np.size(im2, 2)):
    return
  
  xsz = np.size(im1, 0)
  ysz = np.size(im1, 1)
  pts_mean = (1 - warp_func) * im1_pts + warp_func * im2_pts
  tri = Delaunay(pts_mean)
  
  # 显示三角分割
  figure()
  # 图形1的三角分割
  sub1 = subplot(2, 2, 1)
  plt.imshow(im1)
  print(im1_pts)
  sub1.triplot(im1_pts[:, 0], im1_pts[:, 1], tri.simplices.copy())
  sub1.plot(im1_pts[:, 0], im1_pts[:, 1], 'o')
  
  # 图形2的三角分割
  sub2 = subplot(2, 2, 2)
  plt.imshow(im2)
  sub2.triplot(im2_pts[:, 0], im2_pts[:, 1], tri.simplices.copy())
  sub2.plot(im2_pts[:, 0], im2_pts[:, 1], 'o')
  
  # 平均三角分割的形状图形
  sub_mean = subplot(2, 2, 3)
  
  sub_mean.triplot(pts_mean[:, 0], ysz - pts_mean[:, 1], tri.simplices.copy())
  sub_mean.plot(pts_mean[:, 0], ysz - pts_mean[:, 1], 'o')

  # 关闭三角分割图片
  plt.pause(10)
  plt.close()
  
  # 获取重心坐标参数
  barycentric = zeros((xsz * ysz, 3), dtype=float32)
  tris = tri.simplices
  xy_mesh = np.array(list((x,y) for x in range(xsz) for y in range(ysz))) # 生成网格
  
  # 网格点对应像素点所在三角形在平均图形坐标点集中的索引
  triang = tri.find_simplex([xy_mesh])[0]
  for i in arange(0, xsz * ysz):
    xp,yp = xy_mesh[i]
    trindex = tri.find_simplex([xp, yp])
    
    if trindex == -1:
      # 当前像素点不在三角分割内时不计算入生成图形内
      barycentric[i, :] = [float('nan'), float('nan'), float('nan')]
      continue
   
    ax = pts_mean[tris[trindex, 0], 0]
    bx = pts_mean[tris[trindex, 1], 0]
    cx = pts_mean[tris[trindex, 2], 0]
    ay = pts_mean[tris[trindex, 0], 1]
    by = pts_mean[tris[trindex, 1], 1]
    cy = pts_mean[tris[trindex, 2], 1]
  
    
    A = np.mat([[ax, bx, cx], [ay, by, cy], [1, 1, 1]])
    b = np.mat([[xp], [yp], [1]])
    mul = inv(A) * b
    barycentric[i] = [mul[0,0], mul[1,0], mul[2,0]]
  
  # 用重心坐标计算生成图像每个像素对应的原图像像素的坐标
  im1_crsp = zeros((xsz * ysz, 3), dtype=float32)
  for i in arange(0, xsz * ysz):
    trindex = triang[i]
    if trindex == -1:
      im1_crsp[i] = [255,255,255]
      continue
    
    ax = im1_pts[tris[trindex, 0], 0]
    bx = im1_pts[tris[trindex, 1], 0]
    cx = im1_pts[tris[trindex, 2], 0]
    ay = im1_pts[tris[trindex, 0], 1]
    by = im1_pts[tris[trindex, 1], 1]
    cy = im1_pts[tris[trindex, 2], 1]
    
    A = np.mat([[ax, bx, cx], [ay, by, cy], [1, 1, 1]])
    
    X = A * np.mat(barycentric[i]).reshape(3, 1)
    X = uint8(around(X))
    im1_crsp[i] = im1[X[1, 0], X[0, 0]] # 这个地方xy是反着的，因为不知名原因im1的像素点的索引和颜色值是y=x的直线对称的
  
  # 用重心坐标计算生成图像每个像素对应的原图像像素的坐标
  im2_crsp = zeros((xsz * ysz, 3), dtype=float32)
  for i in arange(0, xsz * ysz):
    trindex = triang[i]
    if trindex == -1:
      im2_crsp[i] = [255,255,255]
      continue
   
    ax = im2_pts[tris[trindex, 0], 0]
    bx = im2_pts[tris[trindex, 1], 0]
    cx = im2_pts[tris[trindex, 2], 0]
    ay = im2_pts[tris[trindex, 0], 1]
    by = im2_pts[tris[trindex, 1], 1]
    cy = im2_pts[tris[trindex, 2], 1]
    
    A = [[ax, bx, cx], [ay, by, cy], [1, 1, 1]]
    
    X = A * np.mat(barycentric[i]).reshape(3, 1)
    X = uint8(around(X))
    im2_crsp[i] = im2[X[1,0], X[0,0]]
    
  print("finish warping")
  
  morphed_im = zeros((xsz * ysz, 3), dtype=float32).reshape(xsz, ysz, 3)
  morphed_im1 = zeros((xsz * ysz, 3), dtype=uint8).reshape(xsz, ysz, 3)
  morphed_im2 = zeros((xsz * ysz, 3), dtype=uint8).reshape(xsz, ysz, 3)

  # 将前面计算好的平均图像的像素值填充到数组中
  for i in arange(0, xsz * ysz):
    xp,yp = xy_mesh[i]
    morphed_im1[yp, xp] = im1_crsp[i]
    morphed_im2[yp, xp] = im2_crsp[i]
  
  
  morphed_im = (1 - dissolve_func) * morphed_im1 + dissolve_func * morphed_im2
  morphed_im = uint8(around(morphed_im))
  print("finish morphing")
  figure()
  sub1 = subplot(2, 2, 1)
  plt.imshow(morphed_im1)
  sub2 = subplot(2, 2, 2)
  plt.imshow(morphed_im2)
  sub2 = subplot(2, 2, 3)
  plt.imshow(morphed_im)
  sub_mean = subplot(2, 2, 4)
  plt.imshow(morphed_im)
  plt.pause(10)
  plt.close()
  return morphed_im

if __name__ == '__main__':
  # 展示两张图片
  im1 = array(Image.open('p1.jpg'))
  im2 = array(Image.open('p2.jpg'))
  p1 = np.array(np.loadtxt(fname="p1.csv", dtype=np.float32, delimiter=","))
  p2 = np.array(np.loadtxt(fname="p2.csv", dtype=np.float32, delimiter=","))
  im = morphed_im(im1, im2, p1, p2, 0.5, 0.5)
  cv2.cvtColor(im, cv2.COLOR_BGR2RGB, im); # 转换色域
  cv2.imwrite('./res.jpg', im)